import cv2
import numpy as np

from model.bbox_transform import bbox_transform_inv
from model.config import cfg
from utils import Timer, im_list_to_blob, nms


def _get_image_blob(im):
    """Converts an image into a network input.
    Arguments:
      im (ndarray): a color image in BGR order
    Returns:
      blob (ndarray): a data blob holding an image pyramid
      im_scale_factors (list): list of image scales (relative to im) used
        in the image pyramid
    """
    im_orig = im.astype(np.float32, copy=True)
    im_orig -= cfg.PIXEL_MEANS

    im_shape = im_orig.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    processed_ims = []
    im_scale_factors = []

    for target_size in cfg.TEST.SCALES:
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than MAX_SIZE
        if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
            im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
        im_resized = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                                interpolation=cv2.INTER_LINEAR)
        im_scale_factors.append(im_scale)
        processed_ims.append(im_resized)

    # Create a blob to hold the input images
    blob = im_list_to_blob(processed_ims)

    return blob, np.array(im_scale_factors)
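
# Worked example of the scaling rule above (hedged; assumes the common
# defaults cfg.TEST.SCALES = (600,) and cfg.TEST.MAX_SIZE = 1000): a
# 480 x 640 BGR image gets im_scale = 600 / 480 = 1.25 and a
# 1 x 600 x 800 x 3 blob, with im_scale_factors == [1.25]. A 400 x 1000
# image would first get im_scale = 1.5, but 1.5 * 1000 exceeds MAX_SIZE,
# so im_scale is capped at 1000 / 1000 = 1.0.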


def _get_blobs(im):
    """Convert an image and RoIs within that image into network inputs."""
    blobs = {}
    blobs['data'], im_scale_factors = _get_image_blob(im)

    return blobs, im_scale_factors


def _clip_boxes(boxes, im_shape):
    """Clip boxes to image boundaries."""
    # x1 >= 0
    boxes[:, 0::4] = np.maximum(boxes[:, 0::4], 0)
    # y1 >= 0
    boxes[:, 1::4] = np.maximum(boxes[:, 1::4], 0)
    # x2 < im_shape[1]
    boxes[:, 2::4] = np.minimum(boxes[:, 2::4], im_shape[1] - 1)
    # y2 < im_shape[0]
    boxes[:, 3::4] = np.minimum(boxes[:, 3::4], im_shape[0] - 1)
    return boxes


def _rescale_boxes(boxes, inds, scales):
    """Rescale boxes according to image rescaling."""
    for i in range(boxes.shape[0]):
        boxes[i, :] = boxes[i, :] / scales[int(inds[i])]

    return boxes


def im_detect(sess, net, im):
    """Detect objects in a single image.
    Returns (scores, pred_boxes, feats): an R x K score array, an
    R x (4*K) array of per-class boxes, and an R x D feature array.
    """
    blobs, im_scales = _get_blobs(im)
    assert len(im_scales) == 1, "Only single-image batch implemented"

    im_blob = blobs['data']
    # im_info is [height, width, scale] of the resized image blob; the
    # scale is 1.0 when the image is fed at its original size.
    blobs['im_info'] = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)

    _, scores, bbox_pred, rois, feats = net.test_image(sess, blobs['data'], blobs['im_info'])

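    # rois come back as (batch_idx, x1, y1, x2, y2) in the coordinates of
    # the resized blob; drop the batch index and undo the scaling to get
    # boxes in original-image coordinates.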
    boxes = rois[:, 1:5] / im_scales[0]

    if cfg.TEST.BBOX_REG:
        # Apply bounding-box regression deltas
        box_deltas = bbox_pred
        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        pred_boxes = _clip_boxes(pred_boxes, im.shape)
    else:
        # Simply repeat the boxes, once for each class
        pred_boxes = np.tile(boxes, (1, scores.shape[1]))

    return scores, pred_boxes, feats


def apply_nms(all_boxes, thresh):
    """Apply non-maximum suppression, per class and per image, to all
    predicted boxes in all_boxes[cls][image].
    """
    num_classes = len(all_boxes)
    num_images = len(all_boxes[0])
    nms_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
    for cls_ind in range(num_classes):
        for im_ind in range(num_images):
            dets = all_boxes[cls_ind][im_ind]
            if len(dets) == 0:
                continue

            x1 = dets[:, 0]
            y1 = dets[:, 1]
            x2 = dets[:, 2]
            y2 = dets[:, 3]
            scores = dets[:, 4]
            inds = np.where((x2 > x1) & (y2 > y1) & (scores > cfg.TEST.DET_THRESHOLD))[0]
            dets = dets[inds, :]
            if len(dets) == 0:
                continue

            keep = nms(dets, thresh)
            if len(keep) == 0:
                continue
            nms_boxes[cls_ind][im_ind] = dets[keep, :].copy()
    return nms_boxes
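
# Example: with all_boxes[cls][im] holding N x 5 arrays of
# (x1, y1, x2, y2, score), apply_nms(all_boxes, 0.3) keeps, per class and
# image, only detections scoring above cfg.TEST.DET_THRESHOLD whose boxes
# are not suppressed by a higher-scoring box overlapping by more than 0.3
# (assuming the usual greedy IoU-based nms from the repo's utils).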


classes = (
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
    'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
    'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
    'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush')


def get_boxes(sess, net, img, all_feats, all_boxes):
    """Package per-class detections and features into a dict keyed by class name."""
    boxes = {}
    for cls_ind, cls in enumerate(classes):
        if cls == '__background__':
            continue
        dets = all_boxes[cls_ind]
        feats = all_feats[cls_ind]
        if len(dets) == 0:
            continue
        for k in range(dets.shape[0]):
            det = dets[k]
            box = dict()
            box['lt'] = [int(det[0]), int(det[1])]
            box['rb'] = [int(det[2]), int(det[3])]
            box['f'] = feats[k]
            box['cl'] = cls
            box['score'] = det[4]
            box_width = box['rb'][0] - box['lt'][0]
            box_height = box['rb'][1] - box['lt'][1]
            min_size = min(box_width, box_height)
            # Skip tiny boxes whose shorter side is at most 10 pixels
            if min_size > 10:
                boxes.setdefault(cls, []).append(box)
    return boxes
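
# The returned dict maps class name to a list of box records, e.g.
#   {'person': [{'lt': [x1, y1], 'rb': [x2, y2], 'f': <feature vector>,
#                'cl': 'person', 'score': 0.93}, ...], ...}
# (values here are illustrative).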


def extract_regions_and_feats(sess, net, im, max_per_image=10, max_per_class=3, thresh=0.1):
    """Extract detected regions and the feature vector of each region.
    (max_per_class is accepted for API compatibility but currently unused.)
    """
    # all detections are collected into:
    #  all_boxes[cls] = N x 5 array of detections in (x1, y1, x2, y2, score)
    # timers
    _t = {'im_detect': Timer(), 'misc': Timer()}
    if isinstance(im, str):
        im = cv2.imread(im)

    _t['im_detect'].tic()
    scores, boxes, feats = im_detect(sess, net, im)
    _t['im_detect'].toc()

    _t['misc'].tic()

    num_classes = len(classes)
    all_boxes = [[] for _ in range(num_classes)]
    all_feats = [[] for _ in range(num_classes)]
    # skip j = 0, because it's the background class
    for j in range(1, num_classes):
        # Per-class threshold: keep at most max_per_image detections for
        # this class, but never go below the global score threshold.
        if scores.shape[0] > max_per_image:
            image_thresh = np.sort(scores[:, j])[-max_per_image]
        else:
            image_thresh = -np.inf
        th = max(thresh, image_thresh)
        inds = np.where(scores[:, j] > th)[0]
        cls_scores = scores[inds, j]
        cls_boxes = boxes[inds, j * 4:(j + 1) * 4]
        cls_feats = feats[inds, :]
        cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = nms(cls_dets, cfg.TEST.NMS)
        cls_dets = cls_dets[keep, :]
        # keep indexes into the rows of cls_dets, so index the already
        # filtered features, not the full feats array
        cls_feats = cls_feats[keep, :]
        all_boxes[j] = cls_dets
        all_feats[j] = cls_feats

    # Limit to max_per_image detections *over all classes*
    if max_per_image > 0:
        image_scores = np.hstack([all_boxes[j][:, -1]
                                  for j in range(1, num_classes)])
        if len(image_scores) > max_per_image:
            image_thresh = np.sort(image_scores)[-max_per_image]
            for j in range(1, num_classes):
                keep = np.where(all_boxes[j][:, -1] >= image_thresh)[0]
                all_boxes[j] = all_boxes[j][keep, :]
                all_feats[j] = all_feats[j][keep, :]
    _t['misc'].toc()

    # print('im_detect took {:.3f}s, misc took {:.3f}s'.format(
    #     _t['im_detect'].average_time, _t['misc'].average_time))

    boxes = get_boxes(sess, net, im, all_feats, all_boxes)
    return boxes


def extract_imfea(sess, net, img):
    """Extract a whole-image feature from the network's fc7 layer."""
    # Resize to the fixed 224 x 224 input, convert to float32 (so the
    # in-place mean subtraction does not fail on a uint8 image), and
    # subtract the pixel means.
    resized_im = cv2.resize(img, (224, 224)).astype(np.float32)
    resized_im -= cfg.PIXEL_MEANS
    fea = net.extract_fc7(sess, [resized_im])
    return np.squeeze(fea)
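
# Note: the dimensionality of the returned feature depends on the backbone;
# for a VGG16-style network, fc7 is typically a 4096-d vector (an
# assumption; the actual size is whatever net.extract_fc7 produces).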


if __name__ == '__main__':
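    # Minimal usage sketch (hedged: building and restoring the detection
    # network is repo-specific and omitted; `net` must expose the
    # test_image/extract_fc7 methods used above):
    #
    #   import tensorflow as tf
    #   sess = tf.Session()
    #   net = ...  # construct the network and restore trained weights
    #   boxes = extract_regions_and_feats(sess, net, 'example.jpg')
    #   for cls, dets in boxes.items():
    #       print(cls, len(dets))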
    pass
